import itertools
import json
import pathlib
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional
from warnings import warn

import cachetools
import dirtyjson
import numpy as np
import pandas as pd
import tqdm.auto
from beartype import beartype
from typing_extensions import Self

from .queries import next_token_logprobs, next_token_logprobs_cache_info
from .token_logprobs import (  # LOGPROB_CUTOFF_CLASSIFICATION,; LOGPROB_CUTOFF_RELATIVE_PRECISION,
    ABSENT_TOKEN_LOGPROB,
    LOGPROB_CUTOFF_RELATIVE_PRECISION,
    TokenLogprobs,
    TokenSet,
    logsumexp,
)

# Module-level LRU caches, each holding at most 1024 entries.
# Of these, only CACHE_PromptPair_evaluate is used by active code visible in this
# module (it backs PromptPair._evaluate_task_pair_cached); the other three are not
# referenced by any active code here.
CACHE_PromptPair_next_token_logprobs_with_alpha = cachetools.LRUCache(maxsize=1024)
CACHE_PromptPair_logprob_of_continuation = cachetools.LRUCache(maxsize=1024)
CACHE_PromptPair_logprob_of_correct_continuation = cachetools.LRUCache(maxsize=1024)
CACHE_PromptPair_evaluate = cachetools.LRUCache(maxsize=1024)

# Passed as `top_n` to `next_token_logprobs` for every model query issued by PromptPair.
LIMIT_QUERY_TOKENS = 100
# Prefix strings for which PromptPair.compute_subtasks may spawn a sub-task
# (only when the prefix is also one of the pair's ambiguous prefixes).
SUBTASK_PREFIX_TOKENS = [" ", "\n"]


@beartype
def nonempty_prefixes(s: str):
    """Yield each non-empty prefix of ``s``, shortest first, ending with ``s`` itself.

    An empty string yields nothing.
    """
    yield from (s[:end] for end in range(1, len(s) + 1))


@beartype
@dataclass
class PromptPair:
    """A task given as a full prompt, a weakened variant of it, and the expected continuation(s).

    A pair with at least one entry in `wrong_continuations` is treated as a
    classification task; otherwise it is a continuation-probability task.

    Equality and hashing are based on `hash_key()` only, so the derived fields
    (token tuples, token set, sub-tasks, ambiguous prefixes) and `weakening_name`
    do not take part in comparisons.
    """

    full_prompt: str
    weak_prompt: str
    correct_continuation: str
    # Derived by compute_classification_tokens() unless supplied explicitly.
    correct_continuation_tokens: Optional[tuple[tuple[str, ...], ...]] = None
    wrong_continuations: tuple[str, ...] = field(default_factory=tuple)
    # Derived by compute_classification_tokens() unless supplied explicitly.
    wrong_continuations_tokens: Optional[tuple[tuple[str, ...], ...]] = None
    # All non-empty prefixes of the correct and wrong continuations, as a TokenSet.
    classification_tokenset: Optional[TokenSet] = None
    weakening_name: Optional[str] = None
    file_name: Optional[str] = None
    task_number: Optional[int] = None
    # Maps a prefix from SUBTASK_PREFIX_TOKENS to the pair obtained after emitting it.
    subtasks: Optional[dict[str, Self]] = None
    # Prefixes of the correct continuation that are also a prefix of a wrong
    # continuation or one of SUBTASK_PREFIX_TOKENS.
    ambiguous_prefixes: Optional[tuple[str, ...]] = None

    def __eq__(self, other: Any) -> bool:
        """Return True iff `other` is a PromptPair with the same `hash_key()`."""
        if not isinstance(other, PromptPair):
            return False
        return self.hash_key() == other.hash_key()

    def __hash__(self):
        """Hash of `hash_key()`, consistent with `__eq__`."""
        return hash(self.hash_key())

    def hash_key(self):
        """Return the tuple of fields that identifies this pair for equality and hashing."""
        return (
            self.full_prompt,
            self.weak_prompt,
            self.correct_continuation,
            self.wrong_continuations,
            self.file_name,
            self.task_number,
        )

    @property
    def is_classification(self):
        """True when the pair has at least one wrong continuation."""
        return len(self.wrong_continuations) > 0

    @classmethod
    def from_task_json(
        cls, json_line: str, weak_map=None, weakening_name=None, file_name=None, task_number=None, subtasks_recurse=1
    ) -> Self:
        """Build a pair from one JSON-encoded task line.

        The JSON object must contain `"prompt"` and either `"completion"`
        (continuation task) or both `"classes"` and `"answer_index"`
        (classification task). `"classes"` may be a list or a string that
        `dirtyjson` parses into a list.

        Args:
            json_line: JSON text of a single task.
            weak_map: optional callable applied to the prompt to obtain the weak prompt;
                when None the weak prompt equals the full prompt.
            weakening_name, file_name, task_number: stored on the pair as given.
            subtasks_recurse: recursion depth passed to `compute_subtasks`.

        Raises:
            ValueError: if `"classes"` is not a list, or the expected keys are missing.
        """
        data = json.loads(json_line)
        wp = data["prompt"]
        if weak_map is not None:
            wp = weak_map(wp)
        r = cls(
            full_prompt=data["prompt"],
            weak_prompt=wp,
            correct_continuation="",
            weakening_name=weakening_name,
            file_name=file_name,
            task_number=task_number,
        )
        if "completion" in data:
            # Continuation probability task
            r.correct_continuation = data["completion"]
        elif "classes" in data and "answer_index" in data:
            # Classification task
            classes = data["classes"]
            if isinstance(classes, str):
                classes = dirtyjson.loads(classes)
            if not isinstance(classes, list):
                raise ValueError(f"Invalid `classes` in file {file_name!r}: {classes!r}")
            index = data["answer_index"]
            r.correct_continuation = classes[index]
            r.wrong_continuations = tuple(c for i, c in enumerate(classes) if i != index)
            # Fails e.g. for a negative answer_index, which would leave the answer among the wrong ones.
            assert len(r.wrong_continuations) == len(classes) - 1
        else:
            raise ValueError(f"Invalid task json in file {file_name!r}, keys present: {data.keys()}")
        r.compute_classification_tokens()
        r.compute_subtasks(recurse=subtasks_recurse)
        return r

    def to_json_common(self):
        """Return the prompts and continuations as a plain dict (weak prompt under `"weakened_prompt"`)."""
        return {
            "full_prompt": self.full_prompt,
            "weakened_prompt": self.weak_prompt,
            "correct_continuation": self.correct_continuation,
            "wrong_continuations": self.wrong_continuations,
        }

    def compute_subtasks(self, recurse=1):
        """Populate `self.subtasks` for every SUBTASK_PREFIX_TOKENS entry that is an ambiguous prefix.

        Does nothing if `self.subtasks` is already set. With `recurse <= 0` the
        sub-task dict is left empty. Each sub-task appends the prefix to both
        prompts and drops `len(prefix)` leading characters from the continuations;
        sub-tasks are themselves expanded with `recurse - 1`.

        Requires `compute_classification_tokens()` to have been run first.
        """
        if self.subtasks is not None:
            return
        self.subtasks = {}
        if recurse <= 0:
            return
        assert self.ambiguous_prefixes is not None, "run compute_classification_tokens first"
        for t in SUBTASK_PREFIX_TOKENS:
            if t not in self.ambiguous_prefixes:
                continue
            # NOTE(review): wrong continuations are sliced by len(t) even when they do not
            # start with `t`; confirm whether such continuations should be dropped instead.
            s = PromptPair(
                full_prompt=self.full_prompt + t,
                weak_prompt=self.weak_prompt + t,
                correct_continuation=self.correct_continuation[len(t) :],
                wrong_continuations=tuple(wc[len(t) :] for wc in self.wrong_continuations),
                weakening_name=self.weakening_name,
                file_name=self.file_name,
                task_number=self.task_number,
            )
            s.compute_classification_tokens()
            s.compute_subtasks(recurse=recurse - 1)
            self.subtasks[t] = s

    def compute_classification_tokens(self):
        """Derive `classification_tokenset`, `ambiguous_prefixes` and the unambiguous token tuples.

        `correct_continuation_tokens` / `wrong_continuations_tokens` are only
        computed when they are still None; each entry is a 1-tuple holding a
        prefix that belongs to exactly one side and is not ambiguous.
        """
        correct_prefixes = list(nonempty_prefixes(self.correct_continuation))
        wrong_prefixes = list(
            itertools.chain.from_iterable(nonempty_prefixes(w) for w in self.wrong_continuations)
        )
        self.classification_tokenset = TokenSet(correct_prefixes + wrong_prefixes, make_unique=True, sort=True)

        # Sets give O(1) membership tests in the filters below.
        correct_set = set(correct_prefixes)
        wrong_set = set(wrong_prefixes)
        self.ambiguous_prefixes = tuple(correct_set.intersection(wrong_prefixes + SUBTASK_PREFIX_TOKENS))
        ambiguous_set = set(self.ambiguous_prefixes)

        if self.correct_continuation_tokens is None:
            self.correct_continuation_tokens = tuple(
                (cp,) for cp in correct_prefixes if cp not in wrong_set and cp not in ambiguous_set
            )

        if self.wrong_continuations_tokens is None:
            # Order (and duplicates across wrong continuations) follow `wrong_prefixes`.
            self.wrong_continuations_tokens = tuple(
                (wp,) for wp in wrong_prefixes if wp not in correct_set and wp not in ambiguous_set
            )

    def evaluate(
        self, model: str, alphas, temperatures, top_ks=(None,), normalize=True, model_weak=None
    ) -> pd.DataFrame:
        """Evaluate the pair for every combination of alpha, temperature and top_k.

        Args:
            model: model queried with the full prompt.
            alphas, temperatures, top_ks: a single value or a list/tuple/ndarray of values.
                An alpha of None selects the "difference" method (full minus weak logprobs)
                instead of `extrapolate_from_weakened`; a top_k of None disables the
                top-k restriction of the full-prompt distribution.
            normalize: for classification tasks, rescale correct/wrong/unknown to sum to one.
            model_weak: model queried with the weak prompt; defaults to `model`.

        Returns:
            A DataFrame with one row per (alpha, temperature, top_k) combination and columns
            model, model_weak, alpha, temperature, top_k, weakening, lp_correct, p_correct,
            lp_wrong, p_wrong, lp_unknown, p_unknown, file_name, task_number, tvd.
            The frame comes from a shared cache, so callers should not mutate it in place.
        """
        # Normalise to tuples so that the arguments are hashable for the cache key.
        alphas, temperatures, top_ks = [
            tuple(x) if isinstance(x, (list, tuple, np.ndarray)) else (x,) for x in [alphas, temperatures, top_ks]
        ]
        if model_weak is None:
            model_weak = model
        return self._evaluate_task_pair_cached(
            model, alphas, temperatures, top_ks=top_ks, normalize=normalize, model_weak=model_weak
        )

    @staticmethod
    def _column_matches(column, value):
        """Return a boolean mask of rows where `column` holds `value`; None matches missing entries."""
        if value is None:
            return column.isna()
        return column == value

    @cachetools.cachedmethod(lambda _: CACHE_PromptPair_evaluate, key=cachetools.keys.typedkey)
    def _evaluate_task_pair_cached(
        self,
        model: str,
        alphas,
        temperatures,
        top_ks=(None,),
        normalize=True,
        model_weak=None,
    ) -> pd.DataFrame:
        """Cached implementation behind `evaluate`; arguments must be hashable (tuples).

        Results are memoised in CACHE_PromptPair_evaluate.
        NOTE(review): whether `self` is part of the cache key depends on the installed
        cachetools version - confirm.
        """
        if model_weak is None:
            model_weak = model

        res = []
        restrict_to = self.classification_tokenset.tokens if self.is_classification else None
        tl_f = next_token_logprobs(
            self.full_prompt,
            model,
            top_n=LIMIT_QUERY_TOKENS,
            restrict_to=restrict_to,
        )
        tl_w = next_token_logprobs(
            self.weak_prompt,
            model_weak,
            top_n=LIMIT_QUERY_TOKENS,
            restrict_to=restrict_to,
        )
        tvd = tl_f.tvd(tl_w)

        model_tokens = set(TokenSet.from_model(model).tokens)
        assert set(TokenSet.from_model(model_weak).tokens) == model_tokens, "Two models need to have the same token set"
        considered_tokens = set(nonempty_prefixes(self.correct_continuation))
        for wc in self.wrong_continuations:
            considered_tokens |= set(nonempty_prefixes(wc))
        considered_tokens = considered_tokens.intersection(model_tokens)
        # Longer tokens are considered first; ties are broken lexicographically so the
        # processing order (which feeds the relative-precision cutoff below) does not
        # depend on set iteration order.
        considered_tokens = tuple(sorted(considered_tokens, key=lambda x: (-len(x), x)))

        # Bit of a relict, multi-tokens now solved by subtasks
        for ct in self.correct_continuation_tokens:
            assert len(ct) == 1, f"correct_continuation_token {ct!r} is longer than 1 token (unsupported)"
        for wt in self.wrong_continuations_tokens:
            assert len(wt) == 1, f"wrong_continuation_token {wt!r} is longer than 1 token (unsupported)"

        # Hoisted out of the loops: O(1) membership instead of scanning the tuples per token.
        correct_token_set = set(self.correct_continuation_tokens)
        wrong_token_set = set(self.wrong_continuations_tokens)

        for a, t, top_k in itertools.product(alphas, temperatures, top_ks):
            # Limit full prompt to top k tokens, if given
            tl_f2 = tl_f if top_k is None else tl_f.top_tokens(top_n=top_k)

            if a is not None:
                # Main method: interpolation by alpha
                tl = tl_f2.extrapolate_from_weakened(tl_w, alpha=a)
            else:
                # Alternative method: use just the difference
                tl = TokenLogprobs(tl_f2.tokenset, tl_f2.logprobs - tl_w.logprobs_for_tokenset(tl_f2.tokens))
            # The temperature is always applied here; any special-casing is up to with_temperature.
            tl = tl.with_temperature(t)

            lp_correct, lp_wrong, lp_unknown = (
                ABSENT_TOKEN_LOGPROB,
                ABSENT_TOKEN_LOGPROB,
                ABSENT_TOKEN_LOGPROB,
            )

            # In case non-present upper-bounded elements raked up too much logprob mass
            lp_considered = np.array([tl[tok] for tok in considered_tokens])
            assert (
                np.sum(np.exp(lp_considered)) <= 2.0
            ), f"{np.sum(np.exp(lp_considered))=} {lp_considered=} {considered_tokens=}"
            lp_total = logsumexp(lp_considered)
            if lp_total > 0.0:
                lp_considered -= lp_total
            assert np.sum(np.exp(lp_considered)) <= 1.001, f"{np.sum(np.exp(lp_considered))=} {lp_considered=}"

            for tok, lp in zip(considered_tokens, lp_considered):
                if (tok,) in correct_token_set:
                    lp_correct = np.logaddexp(lp, lp_correct)
                elif (tok,) in wrong_token_set:
                    lp_wrong = np.logaddexp(lp, lp_wrong)
                elif self.subtasks and tok in self.subtasks:
                    if lp < np.logaddexp(lp_correct, lp_wrong) + LOGPROB_CUTOFF_RELATIVE_PRECISION:
                        # Irrelevant probability relative to already known correct and wrong
                        lp_unknown = np.logaddexp(lp, lp_unknown)
                    else:
                        lp_subtask = self.subtasks[tok].evaluate(
                            model, alphas, temperatures, top_ks, normalize=False, model_weak=model_weak
                        )
                        # Pick the sub-task row for the current combination; None values are
                        # stored as missing entries and therefore matched with isna().
                        mask = (
                            self._column_matches(lp_subtask.alpha, a)
                            & self._column_matches(lp_subtask.temperature, t)
                            & self._column_matches(lp_subtask.top_k, top_k)
                        )
                        lp_subtask = lp_subtask[mask].reset_index()
                        assert len(lp_subtask) == 1
                        lp_correct = np.logaddexp(lp_correct, lp_subtask["lp_correct"][0] + lp)
                        lp_wrong = np.logaddexp(lp_wrong, lp_subtask["lp_wrong"][0] + lp)
                        lp_unknown = np.logaddexp(lp_unknown, lp_subtask["lp_unknown"][0] + lp)
                # Remaining tokens are prefixes of the correct or a wrong continuation that are
                # neither unambiguously correct/wrong nor handled by a subtask.
                elif self.is_classification or tok in self.ambiguous_prefixes:
                    lp_unknown = np.logaddexp(lp, lp_unknown)
                else:
                    lp_wrong = np.logaddexp(lp, lp_wrong)

            if self.is_classification:
                if normalize:
                    tot = logsumexp([lp_correct, lp_wrong, lp_unknown])
                    tot = np.maximum(tot, LOGPROB_CUTOFF_RELATIVE_PRECISION)
                    lp_correct = lp_correct - tot
                    lp_wrong = lp_wrong - tot
                    lp_unknown = lp_unknown - tot
                    if not np.isclose(tot, LOGPROB_CUTOFF_RELATIVE_PRECISION):
                        assert np.isclose(
                            np.sum(np.exp([lp_correct, lp_wrong, lp_unknown])), 1.0, atol=1e-3
                        ), f"{np.sum(np.exp([lp_correct, lp_wrong, lp_unknown]))=}"
            else:
                # Continuation task: everything that is neither correct nor unknown counts as wrong.
                lp_wrong = np.log(
                    np.maximum(1.0 - np.exp(lp_correct) - np.exp(lp_unknown), np.exp(ABSENT_TOKEN_LOGPROB))
                )
            lp_correct = float(lp_correct)
            assert lp_correct <= 1e-6
            lp_wrong = float(lp_wrong)
            assert lp_wrong <= 1e-6
            lp_unknown = float(lp_unknown)
            assert lp_unknown <= 1e-6

            res.append(
                dict(
                    model=model,
                    model_weak=model_weak,
                    # None (difference method / no temperature) is kept as a missing value
                    # instead of being passed to float(), which would raise TypeError.
                    alpha=None if a is None else float(a),
                    temperature=None if t is None else float(t),
                    top_k=top_k,
                    weakening=self.weakening_name,
                    lp_correct=lp_correct,
                    p_correct=float(np.exp(lp_correct)),
                    lp_wrong=lp_wrong,
                    p_wrong=float(np.exp(lp_wrong)),
                    lp_unknown=lp_unknown,
                    p_unknown=float(np.exp(lp_unknown)),
                    file_name=self.file_name,
                    task_number=self.task_number,
                    tvd=tvd,
                )
            )
        assert res
        return pd.DataFrame(res)


@beartype
def evaluate_task_pairs(task_pairs, models, alphas, temperatures, progress=False) -> pd.DataFrame:
    """Evaluate every task pair with every model and concatenate the results.

    Args:
        task_pairs: iterable of objects exposing `evaluate(model, alphas, temperatures, model_weak=...)`.
        models: a model name, or a list whose items are either a model name or a
            `(model, model_weak)` tuple. A bare tuple is rejected to avoid ambiguity.
        alphas, temperatures: forwarded unchanged to each `evaluate` call.
        progress: when true, show a tqdm progress bar annotated with the change in
            `next_token_logprobs.cache_info()` since the start.

    Returns:
        The per-pair result frames concatenated with a fresh index; an empty
        DataFrame when there is nothing to evaluate.

    Raises:
        TypeError: if `models` is a tuple.
    """
    if isinstance(models, str):
        models = [models]
    if isinstance(models, tuple):
        # Raising a bare string is itself a TypeError in Python 3 and loses the message.
        raise TypeError(
            "Tuples not allowed as a `models` to avoid ambiguity - pass in a str or list of strs or pairs of strs"
        )
    jobs = list(itertools.product(models, task_pairs))

    bar = None
    cache_start = None
    if progress:
        bar = tqdm.auto.tqdm(jobs, desc="evaluate_task_pairs", miniters=1)
        cache_start = np.array(next_token_logprobs.cache_info())

    def update_progress(current_model):
        """Refresh the progress-bar postfix with cache statistics; no-op without a bar."""
        if bar is None:
            return
        cache_delta = np.array(next_token_logprobs.cache_info()) - cache_start
        bar.set_postfix_str(
            f" {cache_delta[1]} LLM calls + {cache_delta[0]} cached calls, current model: {current_model}"
        )
        bar.refresh()

    frames = []
    model = None  # stays None when there are no jobs, so the final refresh cannot hit an unbound name
    for model_spec, task_pair in jobs if bar is None else bar:
        if isinstance(model_spec, tuple):
            model, model_weak = model_spec
        else:
            model = model_weak = model_spec
        update_progress(model)
        frames.append(task_pair.evaluate(model, alphas, temperatures, model_weak=model_weak))
    update_progress(model)

    if not frames:
        # pd.concat raises ValueError on an empty list.
        return pd.DataFrame()
    return pd.concat(frames).reset_index(drop=True)


@beartype
def load_prompt_pairs(
    path: str,
    weak_map: Optional[Callable[[str], str]] = None,
    weakening_name: Optional[str] = None,
    limit: Optional[int] = None,
    subtasks_recurse=1,
) -> list[PromptPair]:
    """Load prompt pairs from a file containing one JSON task per line.

    Args:
        path: path of the file to read (decoded as UTF-8).
        weak_map: optional callable forwarded to `PromptPair.from_task_json`.
        weakening_name: label forwarded to `PromptPair.from_task_json`.
        limit: read at most this many lines from the start of the file; None reads
            all lines, values <= 0 yield an empty list.
        subtasks_recurse: forwarded to `PromptPair.from_task_json`.

    Returns:
        One PromptPair per line read, with `file_name` set to the file's base name
        and `task_number` set to the zero-based line index.
    """
    file_name = pathlib.Path(path).name  # loop-invariant, computed once
    with open(path, encoding="utf-8") as f:
        # islice stops reading at the limit instead of loading the whole file first.
        lines = f if limit is None else itertools.islice(f, max(limit, 0))
        return [
            PromptPair.from_task_json(
                line,
                weak_map=weak_map,
                weakening_name=weakening_name,
                file_name=file_name,
                task_number=i,
                subtasks_recurse=subtasks_recurse,
            )
            for i, line in enumerate(lines)
        ]
